from abc import ABC, abstractmethod
from typing import Tuple, Union

import numpy as np
import torch
from torch import Tensor

from matplotlib import pyplot as plt

from src.utils.noise_schedules import get_noise_schedule
from src.utils.trajectories import Trajectories

from GGNS.gradNS.nested_sampling import NestedSampler
from GGNS.gradNS.param import Param

# NOTE: every alias below resolves to plain `Tensor`; the shapes are documentation only and are not checked at runtime.
FinalStateTensor = Tensor       # Shape = (batch_size, env.dim); a 1-D (batch_size,) tensor is also accepted by backward_sample_trajectories
InitialStateTensor = Tensor     # Shape = (batch_size, env.dim + 1): spatial coordinates followed by the time index
TrajectoryStatesTensor = Tensor # Shape = (batch_size, traj_length + 1, env.dim + 1)
ActionsTensor = Tensor          # Shape = (batch_size, traj_length, action_dim)
LossTensor = Tensor             # Shape = (1,)

LogProbsTensor = Tensor # Shape = (batch_size,)
FullLogProbsTensor = Tensor # Shape = (batch_size, traj_length)
CumulativeLogProbsTensor = Tensor # Shape = (batch_size, traj_length + 1)
MeanLossTensor = Tensor # Shape = (traj_length,)


class GFlowNet(ABC):
    """Base class for all GFN models."""

    def __init__(self, env, config, forward_model, backward_model):
        self.config = config
        self.env = env
        assert self.config["gfn"]["lr_schedule"] in ["none", "linear", "exponential"], "Invalid learning rate schedule"

        self.batch_size = self.config["batch_size"]
        self.trajectory_length = self.config["gfn"]["trajectory_length"]
        self.device = self.config["device"]

        self.optimizer = None
        self.scheduler = None

        self.off_policy_noise_schedule = get_noise_schedule(config)

        self.forward_model = forward_model.to(self.device)
        self.backward_model = backward_model.to(self.device)
        self.logZ = None
        self.logF_model = None 
        self.iteration = 0
        self.gradient_clipping = self.config["gfn"]["gradient_clipping"]
        self.clip_value = self.config["gfn"]["clip_value"]
        self.log_reward_clip_min = config["gfn"]["log_reward_clip_min"]  
        self.thompson_sampling = config["gfn"]["thompson_sampling"]
        self.local_search = config["gfn"]["local_search"]
        if self.local_search:
            self.local_search_K = config["gfn"]["local_search_K"]
        self.nested_sampling = config["gfn"]["nested_sampling"]
        self.generated_nested_sample_trajectories = False
        self.nested_samples_dataset = None
        assert not (self.thompson_sampling and self.config["metad"]["active"]), "Thompson Sampling and Metadynamics cannot be active at the same time"
        assert not (self.nested_sampling and self.config["metad"]["active"]), "Nested Sampling and Metadynamics cannot be active at the same time"
        assert not (self.nested_sampling and self.config["replay_buffer"]["active"]), "Nested Sampling and Replay Buffer cannot be active at the same time"


    def step(self, loss: LossTensor):
        """
        Performs a single optimization step on the loss and updates the learning rate schedule if a scheduler is used.

        Input:
            - loss (LossTensor): The current loss to take a gradient step on.
        """
        if loss is not None:
            self.optimizer.zero_grad()
            loss.backward()
            if self.gradient_clipping:
                torch.nn.utils.clip_grad_value_(self.forward_model.parameters(), self.clip_value)
                torch.nn.utils.clip_grad_value_(self.backward_model.parameters(), self.clip_value)
                if self.logF_model is not None:
                    torch.nn.utils.clip_grad_value_(self.logF_model.parameters(), self.clip_value)
                if self.logZ is not None:
                    torch.nn.utils.clip_grad_value_([self.logZ], self.clip_value)
            self.optimizer.step()

            if self.scheduler is not None:
                self.scheduler.step()
        else:
            # Skip this step if the loss is None (e.g. because no heads were selected during Thompson Sampling or no trajectories were sampled from the replay buffer)
            Warning("No loss to take a gradient step on")

    def _init_optimizer(self, tied: bool = False, include_logZ: bool = False, include_logF: bool = False) -> torch.optim.Optimizer:
        """
        Initialises the optimizer. 

        Input:
            - tied (bool): Specifies whether parameter weights are shared across the torsos of the models (logP_F, logP_B, logF).
            - include_logZ (bool): Specifies whether the logZ parameter is included in the optimizer (needed for TB loss).
            - include_logF (bool): Specifies whether the logF model is included in the optimizer (needed for DB and STB loss)
        """
        optimizer_params = []

        if tied:
            # Weights of forward and backward model are tied (except for last layer)
            optimizer_params.extend([
                {'params': self.forward_model.parameters(), 'lr': self.config["gfn"]["lr_model"]},
                {'params': self.backward_model.last_layer.parameters(), 'lr': self.config["gfn"]["lr_model"]}
            ])

        else:
            # Independent weights of the forward and backward model
            optimizer_params.extend([
                {'params': self.forward_model.parameters(), 'lr': self.config["gfn"]["lr_model"]},
                {'params': self.backward_model.parameters(), 'lr': self.config["gfn"]["lr_model"]}
            ])

        if include_logZ:
            optimizer_params.append({'params': [self.logZ], 'lr': self.config["gfn"]["lr_logz"]})

        if include_logF and tied:
            optimizer_params.append({'params': self.logF_model.last_layer.parameters(), 'lr': self.config["gfn"]["lr_model"]})
        elif include_logF and not tied:
            optimizer_params.append({'params': self.logF_model.parameters(), 'lr': self.config["gfn"]["lr_model"]})

        if self.config['gfn']['optimizer'] == "adam":
            optimizer = torch.optim.Adam(optimizer_params)
        elif self.config['gfn']['optimizer'] == "msgd":
            optimizer = torch.optim.SGD(optimizer_params)
        else:
            raise ValueError("Invalid optimizer type")

        return optimizer

    @abstractmethod
    def loss(self, trajs: Trajectories, head_index: int) -> LossTensor:
        """
        Computes the loss given the training objects.
        
        Returns:
            loss - torch.tensor of size (1), the loss of the sampled trajectories over the batch
        """

    def _init_scheduler(self, schedule_type: str = "none"):
        """
        Initialises the learning rate scheduler.
        """
        if schedule_type == "none":
            return None
        elif schedule_type == "linear":
            return torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor = 1, end_factor = 0, total_iters = self.config["n_iterations"])
        elif schedule_type == "exponential":
            return torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma = np.exp(- np.exp(1) / self.config["n_iterations"]))
        else:
            raise ValueError("Invalid schedule type")
    
    def get_loss(self, trajs: Trajectories) -> LossTensor:
        """Returns the loss for a batch of trajectories."""

        if trajs is None:
            return None

        if self.thompson_sampling:
            heads_to_include = self.forward_model.get_heads_to_include()
            if heads_to_include is None:
                raise ValueError("No heads were selected during Thompson Sampling")
            loss = 0
            for head_index in heads_to_include:
                loss_contribution = self.loss(trajs, head_index)
                loss += loss_contribution

            loss /= len(heads_to_include)

            self.iteration += 1

            return loss
        else:
            self.iteration += 1

            return self.loss(trajs)

    def run_batch(self) -> Tuple[Trajectories, LossTensor]:
        """
        Runs an on-policy batch.

        Trajectories are sampled according to the current forward policy.
        
        Returns:
            traj       - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
            loss       - torch.tensor of size (1), the loss of the sampled trajectories (mean over batch)
        """
        trajs = self.generate_trajs()
        loss = self.get_loss(trajs)

        return trajs, loss

    def run_metadynamics_batch(self, metad_sampler) -> Tuple[Trajectories, LossTensor]:
        """
        Runs a metadynamics batch.

        The final states of a batch of trajectories are sampled from the metadynamics process.
        Trajectories are then backward sampled from these final states to the initial state using the current backward policy.

        Returns:
            traj       - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
            loss       - torch.tensor of size (1), the loss of the sampled trajectories (mean over batch)
        """
        # Use metadynamics exploration to generate a batch of final states
        with torch.no_grad():
            y, log_rewards = metad_sampler.sample()

        # Use the backward policy to backward sample trajectories from the metadynamics final states
        states, actions = self.backward_sample_trajectories(y)

        # Initialise the Trajectories object and compute the logPF
        trajs = Trajectories(states, actions, log_rewards)
        loss = self.get_loss(trajs)

        return trajs, loss
    
    def run_nested_sampling_batch(self) -> Tuple[Trajectories, LossTensor]:
        """
        Runs a nested sampling batch, using the algorithm from the GGNS package.
        (https://arxiv.org/pdf/2312.03911)

        The first time this is called, the nested sampling algorithm is used to generate a set of final states (which is saved).
        Trajectories are then backward sampled from these final states to the initial state using the current backward policy.
        """
        if not self.generated_nested_sample_trajectories:
            # Generate the trajectories
            loglike = lambda x: self.env.log_reward(x)
            params = [Param(name=f'p{i}',
                prior_type='Uniform',
                prior=(self.env.lower_bound[i], self.env.upper_bound[i]),
                label=f'p_{i}')
            for i in range(self.env.dim)]
            ns = NestedSampler(loglike, params, device = self.device, clustering=False, verbose=False)
            ns.run()
            an = ns.convert_to_anesthetic()
            ns_samples = an.posterior_points()
            self.nested_samples_dataset = np.array(ns_samples[ns.paramnames])
            states, actions = self._generate_trajectories_from_final_states()
            self.generated_nested_sample_trajectories = True
        else:
            states, actions = self._generate_trajectories_from_final_states()

        log_rewards = self.env.log_reward(states[:, -1, :-1].squeeze())
        trajs = Trajectories(states, actions, log_rewards)
        loss = self.get_loss(trajs)

        return trajs, loss

    def run_replay_batch(self, replay_buffer) -> Tuple[Trajectories, LossTensor]:
        """
        Runs a replay batch.
        
        A batch of trajectories are sampled from the reply buffer using a biased sampling: a fraction
        alpha of the sampled trajectories are drawn from the top beta percentile of all trajectories in the 
        replay buffer, sorted by reward. The remaining 1-alpha fraction are sampled from the 1-beta percentiles.
        
        Returns:
            traj       - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
            loss       - torch.tensor of size (1), the loss of the sampled trajectories (mean over batch)
        """
        trajs = replay_buffer.sample()
        loss = self.get_loss(trajs)

        return trajs, loss
    
    def _sample_one_round_on_policy(self, states: TrajectoryStatesTensor, batch_size: int, head_index: Union[int, None]):
        """
        Sample batch_size number of on-policy trajectories.
        
        Inputs:
            states - torch.tensor of size (batch_size, traj_length + 1, 2), a pre-initialised trajectory tensor
            batch_size - int, the number of trajectories to sample
            head_index - int, the index of the head to use for sampling (None if Thompson Sampling is not used)

        Returns:
            states - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
        """
        x: InitialStateTensor = self._init_batch(batch_size=batch_size)

        for t in range(self.trajectory_length):
            forward_policy_dist = self.get_forward_policy_dist(x, head_index)
            action = forward_policy_dist.sample()
            new_x = self.env.step(x, action)
            states[:, t + 1, :] = new_x
            x = new_x

        return states
    
    def sample_on_policy(self, batch_size: int = 10_000) -> Trajectories:
        """
        Sample batch_size number of on-policy trajectories.

        Returns:
            traj - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories.
        """

        with torch.no_grad():
            states: TrajectoryStatesTensor = torch.zeros((batch_size, self.trajectory_length + 1, self.env.dim + 1), device=self.device) 
            states[:, 0, :-1] = self.env.init_value
            if not self.thompson_sampling:
                states = self._sample_one_round_on_policy(states, batch_size, head_index=None)
            else:
                # When using Thompson Sampling, we sample from each head with equal frequency
                n_heads = self.forward_model.n_heads
                batch_size_per_head = batch_size // n_heads
                for head_index in range(n_heads):
                    states[head_index * batch_size_per_head: (head_index + 1) * batch_size_per_head] = self._sample_one_round_on_policy(states[head_index * batch_size_per_head: (head_index + 1) * batch_size_per_head], batch_size_per_head, head_index)

        trajs = Trajectories(states, None, None)

        return trajs
    
    def generate_trajs(self) -> Trajectories:
        """
        Generates a batch of trajectories and computes the logPF and batch_reward.

        The batch is either generated (1) entirely on-policy, (2) using noisy exploration, (3) Thompson Sampling, (4) epsilon noisy exploration.

        Returns:
            trajs       - Trajectories object, the sampled trajectories
        """
        states: TrajectoryStatesTensor = torch.zeros((self.batch_size, self.trajectory_length + 1, self.env.dim + 1), device=self.device) 
        actions: ActionsTensor = torch.zeros((self.batch_size, self.trajectory_length, self.env.action_dim), device=self.device) #TrajectoryTensor

        # Set initial state
        x: InitialStateTensor = self._init_batch() 
        states[:, 0, :] = x

        if self.thompson_sampling:
            head_index = self.forward_model.get_random_head()
        else:
            head_index = None

        # Generate full trajectories and compute logPFs
        if self.config["gfn"]["noise_exploration"]["active"]:
            states, actions = self._traj_using_off_policy_noise(x, states, actions, head_index)
        elif self.local_search:
            states, actions = self._traj_using_local_search(x, states, actions, head_index)
        else:
            states, actions = self._traj_onpolicy(x, states, actions, head_index)

        # Create Trajectories object
        log_rewards = self.env.log_reward(states[:, -1, :-1].squeeze())
        trajs = Trajectories(states, actions, log_rewards)

        return trajs

    def get_forward_policy_dist(self, states: TrajectoryStatesTensor, head_index: Union[None, int] = None) -> torch.distributions.Distribution:
        """
        Returns the forward policy dist for a batch of trajectories.
        
        Returns:
            forward_policy_dist - dimension (batch_size) normal distribution with mean and std given by the forward policy applied to the batch.
        """
        if len(states.shape) == 1:
            states = torch.unsqueeze(states, 0)

        pf_params = self.forward_model(self.env.featurisation(states), head_index)
        forward_policy_dist = self.env.get_policy_dist(pf_params)

        return forward_policy_dist
    
    def get_forward_exploration_policy_dist(self, states: TrajectoryStatesTensor, head_index: Union[None, int] = None):
        """
        Returns the noisy exploration policy dist applied to a batch of trajectories.

        Returns:
            exploration_dist - dimension (batch_size) normal distribution with mean and std given by a noisy forward policy applied to the batch.
        
        """
        off_policy_noise = self.off_policy_noise_schedule[self.iteration]
        pf_params = self.forward_model(self.env.featurisation(states), head_index)
        forward_exploration_policy_dist = self.env.get_exploration_dist(pf_params, off_policy_noise)

        return forward_exploration_policy_dist

    def get_backward_policy_dist(self, states: TrajectoryStatesTensor):
        """
        Returns the backward policy dist for a batch of trajectories.
        
        Returns:
            backward_policy_dist - dimension (batch_size) normal distribution with mean and std given by the backward policy applied to the batch.
        """
        pb_params = self.backward_model(self.env.featurisation(states))
        backward_policy_dist = self.env.get_policy_dist(pb_params)

        return backward_policy_dist
    
    def get_backward_exploration_policy_dist(self, states: TrajectoryStatesTensor):
        """
        Returns the backward policy dist and the noisy backward policy dist applied to a batch of trajectories.
        
        Returns:
            exploration_dist - dimension (batch_size) normal distribution with mean and std given by a noisy backward policy applied to the batch.
        """
        off_policy_noise = self.off_policy_noise_schedule[self.iteration]
        pb_params = self.backward_model(self.env.featurisation(states))
        backward_exploration_policy_dist = self.env.get_exploration_dist(pb_params, off_policy_noise)

        return backward_exploration_policy_dist
    
    def backward_sample_trajectories(self, final_states: FinalStateTensor, batch_size: Union[None,int] = None):
        """
        Generates trajectories with a given set of final states by backsampling with the backward policy.

        Returns:
            traj - torch.tensor of size (batch_size, traj_length + 1, env.dim + 1), the sampled trajectories.
            logPB - torch.tensor of size (batch_size), the cumulative logPB of each trajectory in the batch.
        """
        if batch_size is None: batch_size = self.batch_size
        # Unsqueeze final states if it is shape (batch_size,)
        if len(final_states.shape) == 1:
            final_states = final_states.unsqueeze(-1)
        actions: ActionsTensor = torch.zeros((batch_size, self.trajectory_length, self.env.action_dim), device=self.device) 
        states: TrajectoryStatesTensor = torch.zeros((batch_size, self.trajectory_length + 1, self.env.dim + 1), device=self.device) 
        x: InitialStateTensor = torch.cat([final_states, self.trajectory_length * torch.ones(batch_size, 1, device=self.device)], dim=1) 
        states[:, -1, :] = x
        for t in range(self.trajectory_length - 1, 0, -1):
            backward_exploration_policy_dist = self.get_backward_exploration_policy_dist(x)
            action = backward_exploration_policy_dist.sample()
            if action.dim() == 1:
                actions[:, t, :] = action.unsqueeze(dim=1)
            else:
                actions[:, t, :] = action
            
            new_x = self.env.backward_step(x, action)
            states[:, t, :] = new_x
            x = new_x
        # Set the initial state
        states[:, 0, :-1] = self.env.init_value
        actions[:, 0, :] = self.env.compute_initial_action(states[:, 1, :-1])

        return states, actions
    
    def _traj_using_off_policy_noise(self, x, states: TrajectoryStatesTensor, actions: ActionsTensor, head_index: Union[None, int] = None) -> Tuple[TrajectoryStatesTensor, ActionsTensor]:
        """
        Generates a batch of trajectories using off-policy noise and computes the cummulative logPF of those trajectories.

        Inputs:
            traj - torch.tensor of size (batch_size, traj_length + 1, 2), a pre-initialised trajectory tensor
            logPF - torch.tensor of size (batch_size), a pre-initialised cumulative logPF tensor

        Returns:
            traj - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
            logPF - torch.tensor of size (batch_size), the cummulative logPF of each trajectory in the batch
        """
        # Use off-policy noise exploration
        for t in range(self.trajectory_length):
            forward_exploration_policy_dist = self.get_forward_exploration_policy_dist(x, head_index)
            action = forward_exploration_policy_dist.sample()

            actions[:, t, :] = action

            new_x = self.env.step(x, action) 
            states[:, t + 1, :] = new_x
            x = new_x

        return states, actions
    
    def _traj_using_local_search(self, x, states: TrajectoryStatesTensor, actions: ActionsTensor, head_index: Union[int, None] = None) -> Tuple[TrajectoryStatesTensor, ActionsTensor]:
        """
        Sample batch_size number of on-policy trajectories with local search.
        (https://arxiv.org/pdf/2310.02710)
        
        Inputs:
            states - torch.tensor of size (batch_size, traj_length + 1, 2), a pre-initialised trajectory tensor
            batch_size - int, the number of trajectories to sample
            head_index - int, the index of the head to use for sampling (None if Thompson Sampling is not used)

        Returns:
            states - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
        """

        for t in range(self.trajectory_length):
            forward_policy_dist = self.get_forward_policy_dist(x, head_index)
            action = forward_policy_dist.sample()
            actions[:, t, :] = action
            new_x = self.env.step(x, action)
            states[:, t + 1, :] = new_x
            x = new_x

        # Backward sample the last K steps [tau_back = (x = sn, sn-1, ..., sn-K)]
        x = states[:, -1, :]
        for t in range(self.trajectory_length - 1, self.trajectory_length - 1 - self.local_search_K, -1):
            backward_policy_dist = self.get_backward_policy_dist(x)
            action = backward_policy_dist.sample()
            new_x = self.env.backward_step(states[:, t + 1, :], action)
            x = new_x

        # Reconstruct the last K steps [tau_recon = (sn-K, sn-K+1, ..., sn)]
        new_states = states.clone()
        new_actions = actions.clone()
        for t in range(self.trajectory_length - self.local_search_K, self.trajectory_length):
            forward_policy_dist = self.get_forward_policy_dist(x, head_index)
            action = forward_policy_dist.sample()
            new_actions[:, t, :] = action
            new_x = self.env.step(x, action)
            new_states[:, t + 1, :] = new_x
            x = new_x

            # Keep whichever of the two trajectories has the highest reward
            new_log_rewards = self.env.log_reward(new_states[:, -1, :-1].squeeze())   # reward of the reconstructed trajectory
            log_rewards = self.env.log_reward(states[:, -1, :-1].squeeze())           # reward of the original trajectory
            batch_ids_mask = new_log_rewards > log_rewards                            # where the reconstructed trajectory has a higher reward
            states[batch_ids_mask, :, :] = new_states[batch_ids_mask, :, :]           # replace the original trajectory only with the higher-reward reconstructed ones
            actions[batch_ids_mask, :, :] = new_actions[batch_ids_mask, :, :]

        return states, actions
    
    def _generate_trajectories_from_final_states(self):
        try:
            indices = np.random.choice(self.nested_samples_dataset.shape[0], self.batch_size, replace=False)
        except ValueError:
            print("Warning: Nested sampling generated fewer samples than the batch size. Generating more samples.")
            indices = np.random.choice(self.nested_samples_dataset.shape[0], self.batch_size, replace=True)
        final_states = torch.tensor(self.nested_samples_dataset[indices], device=self.device, dtype=torch.float32)
        states, actions = self.backward_sample_trajectories(final_states, batch_size=self.batch_size)

        return states, actions
    
    def _traj_onpolicy(self, x, states: TrajectoryStatesTensor, actions: ActionsTensor, head_index: Union[None, int] = None) -> Tuple[TrajectoryStatesTensor, ActionsTensor]:
        """
        Generates a batch of trajectories using on-policy sampling and computes the cummulative logPF of those trajectories.

        Inputs:
            traj - torch.tensor of size (batch_size, traj_length + 1, 2), a pre-initialised trajectory tensor
            logPF - torch.tensor of size (batch_size), a pre-initialised cumulative logPF tensor

        Returns:
            traj - torch.tensor of size (batch_size, traj_length + 1, 2), the sampled trajectories
            logPF - torch.tensor of size (batch_size), the cummulative logPF of each trajectory in the batch
        """
        for t in range(self.trajectory_length):
            forward_policy_dist = self.get_forward_policy_dist(x, head_index)
            action = forward_policy_dist.sample()
            actions[:, t, :] = action

            new_x = self.env.step(x, action)
            states[:, t + 1, :] = new_x
            x = new_x

        return states, actions

    def _init_batch(self, batch_size: Union[None, int] = None) -> InitialStateTensor:
        """Trajectory starts at state = (X_0, t=0)."""
        if batch_size is None:
            batch = torch.zeros((self.batch_size, self.env.dim + 1), device=self.device)
        else:
            batch = torch.zeros((batch_size, self.env.dim + 1), device=self.device)

        # Initialise the spatial state dimensions with the initial value
        batch[:, :-1] = self.env.init_value

        return batch
    
    def L1_error(self, samples: int = 1000, show_plot = False):
        """
        Computes the L1 error between the empirical and target distributions of the terminal states of the on-policy trajectories.
        """
        # Sample 'samples' number of on-policy trajectories and record their terminal states
        trajs = self.sample_on_policy(samples)

        # Compute the empirical distribution of the terminal states
        onpolicy_dist = self.env.get_onpolicy_dist(trajs.states)  
        # Compute the L1 error between the empirical and target distributions (rescaled so that it is between 0 and 1)
        L1_error = torch.sum(torch.abs(onpolicy_dist - self.env.target_density)) / 2

        if show_plot:
            plt.plot(onpolicy_dist, label="Empirical")
            plt.plot(self.env.target_density, label="Target")
            plt.show()

        return L1_error.item()
    